#HIDE
# @title Loading libraries and setting up drive
%%capture
from google.colab import drive
drive.mount('/content/drive')
%cd '/content/drive/MyDrive/Colab Notebooks/Explainable-AI'
# Import the necessary packages
import io
import pandas as pd
import numpy as np
import os
import glob
import random
import PIL
import seaborn as sns
import pickle
from PIL import *
import cv2
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.initializers import glorot_uniform
from tensorflow.keras.utils import plot_model, to_categorical
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint, LearningRateScheduler, Callback, CSVLogger
from IPython.display import display
from tensorflow.keras import layers, optimizers
from tensorflow.keras.initializers import glorot_uniform
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.preprocessing import StandardScaler, normalize
from tensorflow.keras import layers, optimizers
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.layers import *
from tensorflow.keras import backend as K
from keras import optimizers
import matplotlib.pyplot as plt
import json
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from google.colab.patches import cv2_imshow
import copy
from graphviz import Digraph
from IPython.display import Javascript, HTML, Image, display, SVG
import copy
from scipy.ndimage import rotate
import requests
import plotly.graph_objects as go
from sklearn.metrics import accuracy_score
import tensorflow as tf
from tensorflow.keras import backend as K
import matplotlib.image as mpimg
from pathlib import Path
import plotly.io as pio
pio.renderers.default = "notebook_connected"
retrain_model = False # ← Change to True when you want to retrain
rerun_example = False # ← Change to True when you want to rerun the example
push_git = False # <- Change this to True when you wan to push the changes
image_base = Path('/content/drive/MyDrive/Colab Notebooks/Explainable-AI/docs/pics')
Explainable AI¶
Nutshell¶
In this project I use DataRobot to predict the type of food from images, as explained in the course Modern Artificial Intelligence, lectured by Dr. Ryan Ahmed, Ph.D. MBA. DataRobot is an end-to-end enterprise AI platform that automates and accelerates every step from data to value.
Data¶
The original dataset from https://www.kaggle.com/vermaavi/food11 consists of 16643 color images belonging to 11 categories. Due to data limitations I will use pictures from 4 classes only:
- Dessert
- Seafood
- Fried food
- Vegetable-Fruit
Grad-CAM visualization¶
Gradient-Weighted Class Activation Mapping (Grad-CAM) makes it possible to visualize the regions of the input that contributed towards the prediction made by the model. It does so by using the class-specific gradient information flowing into the final convolutional layer of the CNN to localize the important regions in the image that resulted in predicting that particular class.
Steps¶
- To visualize the activation maps, first the image has to be passed through the model to make the prediction. Using argmax find the index corresponding to the maximum value in the prediction - this is the predicted class.
- Next, the gradient that is used to arrive at the predicted class from the feature map activations A^k is calculated.
$$ \alpha^c_k = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A^k_{ij}}$$ where $$\frac{1}{Z} \sum_i \sum_j = \text{global average pooling}$$ $$\frac{\partial y^c}{\partial A^k_{ij}} = \text{gradients via backprop}$$
- To enhance the filter values that resulted in this prediction, the gradients computed with tensorflow.GradientTape() are multiplied with the filter values in the last convolutional layer.
- This enhances the filter values that contributed towards making this prediction and lower the filter values that didn't contribute.
- Then, the weighted combination of activation maps is performed and followed by ReLU to obtain the heatmap.
$$ L^c_{Grad-CAM}= ReLU(\sum_k \alpha^c_k A^k) $$ where $\alpha^c_k$ are the neuron importance weights.
Finally, the feature heatmap is super-imposed on the original image to see the activation locations in the image.
#HIDE
# Shared color definitions used by every figure in this notebook.
palette = [
    "#c7522a", "#e5c185", "#f0daa5", "#fbf2c4",
    "#b8cdab", "#74a892", "#008585", "#004343",
]
grey = "#e9ecef"       # light grey used for text on the dark page background
kp_color = "#c7522a"   # key-point highlight color

# One fill color per Keras layer type, for the architecture diagrams below.
_layer_fill = [
    ("InputLayer", "#ffaa00"),
    ("ZeroPadding2D", "#e9854f"),
    ("Conv2D", "#8ECAE6"),
    ("BatchNormalization", "#219EBC"),
    ("Activation", "#023047"),
    ("ReLU", "#993461"),
    ("Add", "#126782"),
    ("MaxPooling2D", "#f09135"),
    ("AveragePooling2D", "#f09135"),
    ("Flatten", "#bd5665"),
    ("Dense", "#993461"),
    ("Dropout", "#692161"),
]
palette_hex = dict(_layer_fill)
#HIDE
import matplotlib as mpl

# Dark-theme matplotlib defaults: light grey text and axes on fully
# transparent backgrounds, so figures blend into the dark notebook page.
grey = "#e9ecef"
legend_bg = "#272b30"

_theme = {}
# All text-like and line-like elements share the same light grey:
# text, axis labels/titles, tick labels, spines, grid, legend labels.
for _key in (
    "text.color", "axes.labelcolor", "axes.titlecolor",
    "xtick.color", "ytick.color",
    "axes.edgecolor", "grid.color",
    "legend.labelcolor",
):
    _theme[_key] = grey
# Legend gets a dark panel with no border.
_theme["legend.facecolor"] = legend_bg
_theme["legend.edgecolor"] = "none"
# Transparent backgrounds everywhere, including saved figures.
_theme["figure.facecolor"] = "none"
_theme["axes.facecolor"] = "none"
_theme["savefig.facecolor"] = "none"
_theme["savefig.transparent"] = True
mpl.rcParams.update(_theme)
The best-performing model for this task was Regularized Logistic Regression (L2). $$ \begin{array}{lccc} \hline \textbf{Metric} & \textbf{Validation} & \textbf{Cross-validation} & \textbf{Holdout} \\ \hline \text{AUC} & 0.9788 & 0.9877 & 0.9822 \\ \text{Accuracy} & 0.8900 & 0.9188 & 0.9036 \\ \text{Balanced Accuracy} & 0.8885 & 0.9196 & 0.9045 \\ \text{FVE Multinomial} & 0.7699 & 0.8202 & 0.7815 \\ \text{LogLoss} & 0.3188 & 0.2475 & 0.3019 \\ \hline \end{array} $$
# Render the stored DataRobot confusion-matrix screenshot inline.
img_path=image_base / "Regularized_Logistic_Regression_Confusion_Matrix.png"
display(Image(filename=img_path, width=560))
Below is an example from the DataRobot model's attention maps.
# Show the original seafood image side by side with DataRobot's
# attention map and heat map.
paths = [
    image_base / name
    for name in ("seafood.jpg", "seafood_attention_map.png", "seafood_heat_map.png")
]
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 5))
for panel, img_file in zip(axes, paths):
    panel.imshow(mpimg.imread(img_file))
    panel.set_title(img_file.name)  # use the file name as the panel title
    panel.axis("off")
plt.tight_layout()
plt.show()
#HIDE
%%capture
# Switch into the data directory of the Explainable-AI project.
%cd '/content/drive/MyDrive/Colab Notebooks/Explainable-AI/Data'
Applying a Grad-CAM for the Brain tumor detector classifier model¶
Next I will implement a Grad-CAM pipeline for the model built in my other project. This model takes MRI images of the brain as input and classifies them into two classes: contains a brain tumor or not. You can check the project here: Brain tumor detector
#HIDE
%%capture
# Switch into the brain-tumor-detector project, where the MRI data lives.
%cd '/content/drive/MyDrive/Colab Notebooks/brain-tumor-detector/Brain_MRI'
#HIDE
# @title MRI scan
# Load the metadata table and preview one example MRI scan (row 624).
brain_df = pd.read_csv("data_mask.csv")
sample_bgr = cv2.imread(brain_df.image_path[624])
# OpenCV loads images as BGR; convert so matplotlib shows true colors.
sample_rgb = cv2.cvtColor(sample_bgr, cv2.COLOR_BGR2RGB)
fig, ax = plt.subplots(figsize=(4, 4))
fig.patch.set_alpha(0)      # transparent figure bg
ax.set_facecolor("none")    # transparent axes bg
ax.imshow(sample_rgb)
ax.set_title("MRI", color="#e9ecef")
ax.axis("off")
plt.show()
Introduction to the Brain Tumor Detection¶
Deep learning has proven to be as good and even better than humans in detecting diseases from X-rays, MRI scans and CT scans. There is huge potential in using AI to speed up and improve the accuracy of diagnosis. This project will use the labeled dataset from https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation which consists of 3929 Brain MRI scans and the tumor location. The final pipeline has a two-step process where
- A Resnet deep learning classifier model will classify the input images into two groups: tumor detected and tumor not detected.
- For the images, where tumor was detected, a second step is performed, where a ResUNet segmentation model detects the tumor location on the pixel level.
Below is an example of an MRI image and the matching mask. This example has a small tumor. In images where no tumor is present, the mask will be completely black.
#HIDE
# Show a sample MRI image (row 623) next to its ground-truth tumor mask.
mask = brain_df.mask_path[623]
img = brain_df.image_path[623]
fig, axes = plt.subplots(1, 2, figsize=(6, 3))
fig.patch.set_alpha(0)  # transparent figure bg
# Bug fix: the original called ax.set_facecolor on a stale `ax` left over from
# the previous cell's figure, so this figure's axes never became transparent.
for panel in axes:
    panel.set_facecolor("none")  # transparent axes bg
axes[0].imshow(cv2.imread(img))
axes[0].set_title("MRI Image", color="#e9ecef")
axes[0].axis("off")
# NOTE(review): cv2.imread returns a 3-channel BGR array, so cmap="gray" is
# ignored by imshow; if a true grayscale render is wanted, read the mask with
# cv2.IMREAD_GRAYSCALE — confirm intent.
axes[1].imshow(cv2.imread(mask), cmap="gray")
axes[1].set_title("Mask", color="#e9ecef")
axes[1].axis("off")
plt.tight_layout()
plt.show()
Convolutional neural networks (CNNs)¶
- The first CNN layers are used to extract high level general features
- The last couple of layers will perform classification
- Local receptive fields scan the image first, searching for simple shapes such as edges and lines
- The edges are picked up by the subsequent layer to form more complex features
A good visualisation of the feature extraction with convolutions can be found at https://setosa.io/ev/image-kernels/
ResNet (Residual Network)¶
- As CNNs grow deeper, vanishing gradients negatively impact the network performance. The vanishing gradient problem occurs when the gradient is backpropagated to earlier layers, which results in a very small gradient.
- ResNets "skip connection" feature can allow training of 152 layers wihtout vanishing gradient problems
- ResNet adds "identity mapping on top of the CNN
- ResNet deep network is trained with ImageNet, which contains 11 million images and 11 000 categories
ResNet paper (He et al., 2015): https://arxiv.org/pdf/1512.03385
As seen in Figure 6 of the ResNet paper, the ResNet architectures overcome the training challenges of deep networks compared to the plain networks. ResNet-152 achieved a 3.57% error rate on the ImageNet dataset. This is better than human performance.
#HIDE
# Display the stored ResNet training-curve figure referenced above.
image_path = image_base / 'resnetwork.png'
display(Image(filename=image_path))
#HIDE
%%capture
# @title Train test split
#Drop the patient_id (we don't need it)
brain_df_train = brain_df.drop(columns = ['patient_id'])
brain_df_train.head(0)
brain_df_train.info()
#convert the data in mask column into a string format, to use categorical mode
#in flow_fom_dataframe. Otherwise we get TypeError
brain_df_train['mask'] = brain_df_train['mask'].astype(str)
brain_df_train.info()
train, test = train_test_split(brain_df_train, test_size=0.15)
#HIDE
# @title Image generator
# Generator that rescales pixel values from [0, 255] down to [0, 1]
# (dividing by 255 normalises the values) and reserves 15% of the
# dataframe it is given as a validation subset.
datagen = ImageDataGenerator(rescale=1./255, validation_split=0.15)
# @title Preparing image generators
# Training subset: the 85% of `train` kept by validation_split, shuffled.
train_generator = datagen.flow_from_dataframe(
    dataframe = train,
    directory = './',
    x_col = 'image_path',
    y_col = 'mask',
    subset = 'training',
    batch_size =16,
    shuffle = True,
    class_mode = 'categorical',
    target_size = (256, 256)
)
# Validation subset: the 15% of `train` held out by validation_split.
valid_generator = datagen.flow_from_dataframe(
    dataframe = train,
    directory = './',
    x_col = 'image_path',
    y_col = 'mask',
    subset = 'validation',
    batch_size = 16,
    shuffle = True,
    class_mode = 'categorical',
    target_size = (256, 256)
)
# Create a separate generator for test images: rescaling only, and no
# validation split because `test` is already a held-out dataset.
test_datagen = ImageDataGenerator(rescale=1./255)
# Bug fix: the original built test_generator from `datagen`, leaving
# `test_datagen` unused; the test generator must not carry the 15%
# validation_split configured on `datagen`.
test_generator = test_datagen.flow_from_dataframe(
    dataframe = test,
    directory = './',
    x_col = 'image_path',
    y_col = 'mask',
    batch_size = 16,
    shuffle = False,          # keep order stable so predictions align with labels
    class_mode = 'categorical',
    target_size = (256, 256)
)
Found 2839 validated image filenames belonging to 2 classes. Found 500 validated image filenames belonging to 2 classes. Found 590 validated image filenames belonging to 2 classes.
#HIDE
# Load the trained model from the brain-tumor-detector project directory.
%cd '/content/drive/MyDrive/Colab Notebooks/brain-tumor-detector/Brain_MRI'
# 1) Load the saved ResNet classifier (Keras native format)
model = load_model("classifier-resnet-weights.keras")
# (Optional) verify it’s really loaded by printing the architecture
model.summary()
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ │ input_layer │ (None, 256, 256, │ 0 │ - │ │ (InputLayer) │ 3) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv1_pad │ (None, 262, 262, │ 0 │ input_layer[0][0] │ │ (ZeroPadding2D) │ 3) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv1_conv (Conv2D) │ (None, 128, 128, │ 9,472 │ conv1_pad[0][0] │ │ │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv1_bn │ (None, 128, 128, │ 256 │ conv1_conv[0][0] │ │ (BatchNormalizatio… │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv1_relu │ (None, 128, 128, │ 0 │ conv1_bn[0][0] │ │ (Activation) │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ pool1_pad │ (None, 130, 130, │ 0 │ conv1_relu[0][0] │ │ (ZeroPadding2D) │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ pool1_pool │ (None, 64, 64, │ 0 │ pool1_pad[0][0] │ │ (MaxPooling2D) │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block1_1_conv │ (None, 64, 64, │ 4,160 │ pool1_pool[0][0] │ │ (Conv2D) │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block1_1_bn │ (None, 64, 64, │ 256 │ conv2_block1_1_c… │ │ (BatchNormalizatio… │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block1_1_relu │ (None, 64, 64, │ 0 │ conv2_block1_1_b… │ │ (Activation) │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block1_2_conv │ (None, 64, 64, │ 36,928 │ conv2_block1_1_r… │ │ (Conv2D) │ 64) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block1_2_bn │ (None, 64, 64, │ 256 │ conv2_block1_2_c… │ │ (BatchNormalizatio… │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block1_2_relu │ (None, 64, 64, │ 0 │ conv2_block1_2_b… │ │ (Activation) │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block1_0_conv │ (None, 64, 64, │ 16,640 │ pool1_pool[0][0] │ │ (Conv2D) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block1_3_conv │ (None, 64, 64, │ 16,640 │ conv2_block1_2_r… │ │ (Conv2D) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block1_0_bn │ (None, 64, 64, │ 1,024 │ conv2_block1_0_c… │ │ (BatchNormalizatio… │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block1_3_bn │ (None, 64, 64, │ 1,024 │ conv2_block1_3_c… │ │ (BatchNormalizatio… │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block1_add │ (None, 64, 64, │ 0 │ conv2_block1_0_b… │ │ (Add) │ 256) │ │ conv2_block1_3_b… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block1_out │ (None, 64, 64, │ 0 │ conv2_block1_add… │ │ (Activation) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block2_1_conv │ (None, 64, 64, │ 16,448 │ conv2_block1_out… │ │ (Conv2D) │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block2_1_bn │ (None, 64, 64, │ 256 │ conv2_block2_1_c… │ │ (BatchNormalizatio… │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block2_1_relu │ (None, 64, 64, │ 0 │ conv2_block2_1_b… │ │ (Activation) │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ 
conv2_block2_2_conv │ (None, 64, 64, │ 36,928 │ conv2_block2_1_r… │ │ (Conv2D) │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block2_2_bn │ (None, 64, 64, │ 256 │ conv2_block2_2_c… │ │ (BatchNormalizatio… │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block2_2_relu │ (None, 64, 64, │ 0 │ conv2_block2_2_b… │ │ (Activation) │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block2_3_conv │ (None, 64, 64, │ 16,640 │ conv2_block2_2_r… │ │ (Conv2D) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block2_3_bn │ (None, 64, 64, │ 1,024 │ conv2_block2_3_c… │ │ (BatchNormalizatio… │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block2_add │ (None, 64, 64, │ 0 │ conv2_block1_out… │ │ (Add) │ 256) │ │ conv2_block2_3_b… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block2_out │ (None, 64, 64, │ 0 │ conv2_block2_add… │ │ (Activation) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block3_1_conv │ (None, 64, 64, │ 16,448 │ conv2_block2_out… │ │ (Conv2D) │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block3_1_bn │ (None, 64, 64, │ 256 │ conv2_block3_1_c… │ │ (BatchNormalizatio… │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block3_1_relu │ (None, 64, 64, │ 0 │ conv2_block3_1_b… │ │ (Activation) │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block3_2_conv │ (None, 64, 64, │ 36,928 │ conv2_block3_1_r… │ │ (Conv2D) │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block3_2_bn │ (None, 64, 64, │ 256 │ conv2_block3_2_c… │ │ (BatchNormalizatio… │ 64) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block3_2_relu │ (None, 64, 64, │ 0 │ conv2_block3_2_b… │ │ (Activation) │ 64) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block3_3_conv │ (None, 64, 64, │ 16,640 │ conv2_block3_2_r… │ │ (Conv2D) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block3_3_bn │ (None, 64, 64, │ 1,024 │ conv2_block3_3_c… │ │ (BatchNormalizatio… │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block3_add │ (None, 64, 64, │ 0 │ conv2_block2_out… │ │ (Add) │ 256) │ │ conv2_block3_3_b… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2_block3_out │ (None, 64, 64, │ 0 │ conv2_block3_add… │ │ (Activation) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block1_1_conv │ (None, 32, 32, │ 32,896 │ conv2_block3_out… │ │ (Conv2D) │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block1_1_bn │ (None, 32, 32, │ 512 │ conv3_block1_1_c… │ │ (BatchNormalizatio… │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block1_1_relu │ (None, 32, 32, │ 0 │ conv3_block1_1_b… │ │ (Activation) │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block1_2_conv │ (None, 32, 32, │ 147,584 │ conv3_block1_1_r… │ │ (Conv2D) │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block1_2_bn │ (None, 32, 32, │ 512 │ conv3_block1_2_c… │ │ (BatchNormalizatio… │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block1_2_relu │ (None, 32, 32, │ 0 │ conv3_block1_2_b… │ │ (Activation) │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block1_0_conv 
│ (None, 32, 32, │ 131,584 │ conv2_block3_out… │ │ (Conv2D) │ 512) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block1_3_conv │ (None, 32, 32, │ 66,048 │ conv3_block1_2_r… │ │ (Conv2D) │ 512) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block1_0_bn │ (None, 32, 32, │ 2,048 │ conv3_block1_0_c… │ │ (BatchNormalizatio… │ 512) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block1_3_bn │ (None, 32, 32, │ 2,048 │ conv3_block1_3_c… │ │ (BatchNormalizatio… │ 512) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block1_add │ (None, 32, 32, │ 0 │ conv3_block1_0_b… │ │ (Add) │ 512) │ │ conv3_block1_3_b… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block1_out │ (None, 32, 32, │ 0 │ conv3_block1_add… │ │ (Activation) │ 512) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block2_1_conv │ (None, 32, 32, │ 65,664 │ conv3_block1_out… │ │ (Conv2D) │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block2_1_bn │ (None, 32, 32, │ 512 │ conv3_block2_1_c… │ │ (BatchNormalizatio… │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block2_1_relu │ (None, 32, 32, │ 0 │ conv3_block2_1_b… │ │ (Activation) │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block2_2_conv │ (None, 32, 32, │ 147,584 │ conv3_block2_1_r… │ │ (Conv2D) │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block2_2_bn │ (None, 32, 32, │ 512 │ conv3_block2_2_c… │ │ (BatchNormalizatio… │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block2_2_relu │ (None, 32, 32, │ 0 │ conv3_block2_2_b… │ │ (Activation) │ 128) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block2_3_conv │ (None, 32, 32, │ 66,048 │ conv3_block2_2_r… │ │ (Conv2D) │ 512) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block2_3_bn │ (None, 32, 32, │ 2,048 │ conv3_block2_3_c… │ │ (BatchNormalizatio… │ 512) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block2_add │ (None, 32, 32, │ 0 │ conv3_block1_out… │ │ (Add) │ 512) │ │ conv3_block2_3_b… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block2_out │ (None, 32, 32, │ 0 │ conv3_block2_add… │ │ (Activation) │ 512) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block3_1_conv │ (None, 32, 32, │ 65,664 │ conv3_block2_out… │ │ (Conv2D) │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block3_1_bn │ (None, 32, 32, │ 512 │ conv3_block3_1_c… │ │ (BatchNormalizatio… │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block3_1_relu │ (None, 32, 32, │ 0 │ conv3_block3_1_b… │ │ (Activation) │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block3_2_conv │ (None, 32, 32, │ 147,584 │ conv3_block3_1_r… │ │ (Conv2D) │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block3_2_bn │ (None, 32, 32, │ 512 │ conv3_block3_2_c… │ │ (BatchNormalizatio… │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block3_2_relu │ (None, 32, 32, │ 0 │ conv3_block3_2_b… │ │ (Activation) │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block3_3_conv │ (None, 32, 32, │ 66,048 │ conv3_block3_2_r… │ │ (Conv2D) │ 512) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block3_3_bn 
│ (None, 32, 32, │ 2,048 │ conv3_block3_3_c… │ │ (BatchNormalizatio… │ 512) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block3_add │ (None, 32, 32, │ 0 │ conv3_block2_out… │ │ (Add) │ 512) │ │ conv3_block3_3_b… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block3_out │ (None, 32, 32, │ 0 │ conv3_block3_add… │ │ (Activation) │ 512) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block4_1_conv │ (None, 32, 32, │ 65,664 │ conv3_block3_out… │ │ (Conv2D) │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block4_1_bn │ (None, 32, 32, │ 512 │ conv3_block4_1_c… │ │ (BatchNormalizatio… │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block4_1_relu │ (None, 32, 32, │ 0 │ conv3_block4_1_b… │ │ (Activation) │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block4_2_conv │ (None, 32, 32, │ 147,584 │ conv3_block4_1_r… │ │ (Conv2D) │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block4_2_bn │ (None, 32, 32, │ 512 │ conv3_block4_2_c… │ │ (BatchNormalizatio… │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block4_2_relu │ (None, 32, 32, │ 0 │ conv3_block4_2_b… │ │ (Activation) │ 128) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block4_3_conv │ (None, 32, 32, │ 66,048 │ conv3_block4_2_r… │ │ (Conv2D) │ 512) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block4_3_bn │ (None, 32, 32, │ 2,048 │ conv3_block4_3_c… │ │ (BatchNormalizatio… │ 512) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block4_add │ (None, 32, 32, │ 0 │ conv3_block3_out… │ │ (Add) │ 512) │ │ conv3_block4_3_b… │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv3_block4_out │ (None, 32, 32, │ 0 │ conv3_block4_add… │ │ (Activation) │ 512) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block1_1_conv │ (None, 16, 16, │ 131,328 │ conv3_block4_out… │ │ (Conv2D) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block1_1_bn │ (None, 16, 16, │ 1,024 │ conv4_block1_1_c… │ │ (BatchNormalizatio… │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block1_1_relu │ (None, 16, 16, │ 0 │ conv4_block1_1_b… │ │ (Activation) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block1_2_conv │ (None, 16, 16, │ 590,080 │ conv4_block1_1_r… │ │ (Conv2D) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block1_2_bn │ (None, 16, 16, │ 1,024 │ conv4_block1_2_c… │ │ (BatchNormalizatio… │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block1_2_relu │ (None, 16, 16, │ 0 │ conv4_block1_2_b… │ │ (Activation) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block1_0_conv │ (None, 16, 16, │ 525,312 │ conv3_block4_out… │ │ (Conv2D) │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block1_3_conv │ (None, 16, 16, │ 263,168 │ conv4_block1_2_r… │ │ (Conv2D) │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block1_0_bn │ (None, 16, 16, │ 4,096 │ conv4_block1_0_c… │ │ (BatchNormalizatio… │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block1_3_bn │ (None, 16, 16, │ 4,096 │ conv4_block1_3_c… │ │ (BatchNormalizatio… │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ 
conv4_block1_add │ (None, 16, 16, │ 0 │ conv4_block1_0_b… │ │ (Add) │ 1024) │ │ conv4_block1_3_b… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block1_out │ (None, 16, 16, │ 0 │ conv4_block1_add… │ │ (Activation) │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block2_1_conv │ (None, 16, 16, │ 262,400 │ conv4_block1_out… │ │ (Conv2D) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block2_1_bn │ (None, 16, 16, │ 1,024 │ conv4_block2_1_c… │ │ (BatchNormalizatio… │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block2_1_relu │ (None, 16, 16, │ 0 │ conv4_block2_1_b… │ │ (Activation) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block2_2_conv │ (None, 16, 16, │ 590,080 │ conv4_block2_1_r… │ │ (Conv2D) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block2_2_bn │ (None, 16, 16, │ 1,024 │ conv4_block2_2_c… │ │ (BatchNormalizatio… │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block2_2_relu │ (None, 16, 16, │ 0 │ conv4_block2_2_b… │ │ (Activation) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block2_3_conv │ (None, 16, 16, │ 263,168 │ conv4_block2_2_r… │ │ (Conv2D) │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block2_3_bn │ (None, 16, 16, │ 4,096 │ conv4_block2_3_c… │ │ (BatchNormalizatio… │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block2_add │ (None, 16, 16, │ 0 │ conv4_block1_out… │ │ (Add) │ 1024) │ │ conv4_block2_3_b… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block2_out │ (None, 16, 16, │ 0 │ conv4_block2_add… │ │ (Activation) │ 
1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block3_1_conv │ (None, 16, 16, │ 262,400 │ conv4_block2_out… │ │ (Conv2D) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block3_1_bn │ (None, 16, 16, │ 1,024 │ conv4_block3_1_c… │ │ (BatchNormalizatio… │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block3_1_relu │ (None, 16, 16, │ 0 │ conv4_block3_1_b… │ │ (Activation) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block3_2_conv │ (None, 16, 16, │ 590,080 │ conv4_block3_1_r… │ │ (Conv2D) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block3_2_bn │ (None, 16, 16, │ 1,024 │ conv4_block3_2_c… │ │ (BatchNormalizatio… │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block3_2_relu │ (None, 16, 16, │ 0 │ conv4_block3_2_b… │ │ (Activation) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block3_3_conv │ (None, 16, 16, │ 263,168 │ conv4_block3_2_r… │ │ (Conv2D) │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block3_3_bn │ (None, 16, 16, │ 4,096 │ conv4_block3_3_c… │ │ (BatchNormalizatio… │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block3_add │ (None, 16, 16, │ 0 │ conv4_block2_out… │ │ (Add) │ 1024) │ │ conv4_block3_3_b… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block3_out │ (None, 16, 16, │ 0 │ conv4_block3_add… │ │ (Activation) │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block4_1_conv │ (None, 16, 16, │ 262,400 │ conv4_block3_out… │ │ (Conv2D) │ 256) │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block4_1_bn │ (None, 16, 16, │ 1,024 │ conv4_block4_1_c… │ │ (BatchNormalizatio… │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block4_1_relu │ (None, 16, 16, │ 0 │ conv4_block4_1_b… │ │ (Activation) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block4_2_conv │ (None, 16, 16, │ 590,080 │ conv4_block4_1_r… │ │ (Conv2D) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block4_2_bn │ (None, 16, 16, │ 1,024 │ conv4_block4_2_c… │ │ (BatchNormalizatio… │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block4_2_relu │ (None, 16, 16, │ 0 │ conv4_block4_2_b… │ │ (Activation) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block4_3_conv │ (None, 16, 16, │ 263,168 │ conv4_block4_2_r… │ │ (Conv2D) │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block4_3_bn │ (None, 16, 16, │ 4,096 │ conv4_block4_3_c… │ │ (BatchNormalizatio… │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block4_add │ (None, 16, 16, │ 0 │ conv4_block3_out… │ │ (Add) │ 1024) │ │ conv4_block4_3_b… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block4_out │ (None, 16, 16, │ 0 │ conv4_block4_add… │ │ (Activation) │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block5_1_conv │ (None, 16, 16, │ 262,400 │ conv4_block4_out… │ │ (Conv2D) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block5_1_bn │ (None, 16, 16, │ 1,024 │ conv4_block5_1_c… │ │ (BatchNormalizatio… │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ 
conv4_block5_1_relu │ (None, 16, 16, │ 0 │ conv4_block5_1_b… │ │ (Activation) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block5_2_conv │ (None, 16, 16, │ 590,080 │ conv4_block5_1_r… │ │ (Conv2D) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block5_2_bn │ (None, 16, 16, │ 1,024 │ conv4_block5_2_c… │ │ (BatchNormalizatio… │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block5_2_relu │ (None, 16, 16, │ 0 │ conv4_block5_2_b… │ │ (Activation) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block5_3_conv │ (None, 16, 16, │ 263,168 │ conv4_block5_2_r… │ │ (Conv2D) │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block5_3_bn │ (None, 16, 16, │ 4,096 │ conv4_block5_3_c… │ │ (BatchNormalizatio… │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block5_add │ (None, 16, 16, │ 0 │ conv4_block4_out… │ │ (Add) │ 1024) │ │ conv4_block5_3_b… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block5_out │ (None, 16, 16, │ 0 │ conv4_block5_add… │ │ (Activation) │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block6_1_conv │ (None, 16, 16, │ 262,400 │ conv4_block5_out… │ │ (Conv2D) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block6_1_bn │ (None, 16, 16, │ 1,024 │ conv4_block6_1_c… │ │ (BatchNormalizatio… │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block6_1_relu │ (None, 16, 16, │ 0 │ conv4_block6_1_b… │ │ (Activation) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block6_2_conv │ (None, 16, 16, │ 590,080 │ conv4_block6_1_r… │ │ (Conv2D) │ 256) 
│ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block6_2_bn │ (None, 16, 16, │ 1,024 │ conv4_block6_2_c… │ │ (BatchNormalizatio… │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block6_2_relu │ (None, 16, 16, │ 0 │ conv4_block6_2_b… │ │ (Activation) │ 256) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block6_3_conv │ (None, 16, 16, │ 263,168 │ conv4_block6_2_r… │ │ (Conv2D) │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block6_3_bn │ (None, 16, 16, │ 4,096 │ conv4_block6_3_c… │ │ (BatchNormalizatio… │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block6_add │ (None, 16, 16, │ 0 │ conv4_block5_out… │ │ (Add) │ 1024) │ │ conv4_block6_3_b… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv4_block6_out │ (None, 16, 16, │ 0 │ conv4_block6_add… │ │ (Activation) │ 1024) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block1_1_conv │ (None, 8, 8, 512) │ 524,800 │ conv4_block6_out… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block1_1_bn │ (None, 8, 8, 512) │ 2,048 │ conv5_block1_1_c… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block1_1_relu │ (None, 8, 8, 512) │ 0 │ conv5_block1_1_b… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block1_2_conv │ (None, 8, 8, 512) │ 2,359,808 │ conv5_block1_1_r… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block1_2_bn │ (None, 8, 8, 512) │ 2,048 │ conv5_block1_2_c… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ 
conv5_block1_2_relu │ (None, 8, 8, 512) │ 0 │ conv5_block1_2_b… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block1_0_conv │ (None, 8, 8, │ 2,099,200 │ conv4_block6_out… │ │ (Conv2D) │ 2048) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block1_3_conv │ (None, 8, 8, │ 1,050,624 │ conv5_block1_2_r… │ │ (Conv2D) │ 2048) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block1_0_bn │ (None, 8, 8, │ 8,192 │ conv5_block1_0_c… │ │ (BatchNormalizatio… │ 2048) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block1_3_bn │ (None, 8, 8, │ 8,192 │ conv5_block1_3_c… │ │ (BatchNormalizatio… │ 2048) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block1_add │ (None, 8, 8, │ 0 │ conv5_block1_0_b… │ │ (Add) │ 2048) │ │ conv5_block1_3_b… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block1_out │ (None, 8, 8, │ 0 │ conv5_block1_add… │ │ (Activation) │ 2048) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block2_1_conv │ (None, 8, 8, 512) │ 1,049,088 │ conv5_block1_out… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block2_1_bn │ (None, 8, 8, 512) │ 2,048 │ conv5_block2_1_c… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block2_1_relu │ (None, 8, 8, 512) │ 0 │ conv5_block2_1_b… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block2_2_conv │ (None, 8, 8, 512) │ 2,359,808 │ conv5_block2_1_r… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block2_2_bn │ (None, 8, 8, 512) │ 2,048 │ conv5_block2_2_c… │ │ (BatchNormalizatio… │ │ │ │ 
├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block2_2_relu │ (None, 8, 8, 512) │ 0 │ conv5_block2_2_b… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block2_3_conv │ (None, 8, 8, │ 1,050,624 │ conv5_block2_2_r… │ │ (Conv2D) │ 2048) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block2_3_bn │ (None, 8, 8, │ 8,192 │ conv5_block2_3_c… │ │ (BatchNormalizatio… │ 2048) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block2_add │ (None, 8, 8, │ 0 │ conv5_block1_out… │ │ (Add) │ 2048) │ │ conv5_block2_3_b… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block2_out │ (None, 8, 8, │ 0 │ conv5_block2_add… │ │ (Activation) │ 2048) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block3_1_conv │ (None, 8, 8, 512) │ 1,049,088 │ conv5_block2_out… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block3_1_bn │ (None, 8, 8, 512) │ 2,048 │ conv5_block3_1_c… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block3_1_relu │ (None, 8, 8, 512) │ 0 │ conv5_block3_1_b… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block3_2_conv │ (None, 8, 8, 512) │ 2,359,808 │ conv5_block3_1_r… │ │ (Conv2D) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block3_2_bn │ (None, 8, 8, 512) │ 2,048 │ conv5_block3_2_c… │ │ (BatchNormalizatio… │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block3_2_relu │ (None, 8, 8, 512) │ 0 │ conv5_block3_2_b… │ │ (Activation) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block3_3_conv │ 
(None, 8, 8, │ 1,050,624 │ conv5_block3_2_r… │ │ (Conv2D) │ 2048) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block3_3_bn │ (None, 8, 8, │ 8,192 │ conv5_block3_3_c… │ │ (BatchNormalizatio… │ 2048) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block3_add │ (None, 8, 8, │ 0 │ conv5_block2_out… │ │ (Add) │ 2048) │ │ conv5_block3_3_b… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv5_block3_out │ (None, 8, 8, │ 0 │ conv5_block3_add… │ │ (Activation) │ 2048) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ average_pooling2d │ (None, 2, 2, │ 0 │ conv5_block3_out… │ │ (AveragePooling2D) │ 2048) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ flatten (Flatten) │ (None, 8192) │ 0 │ average_pooling2… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense (Dense) │ (None, 256) │ 2,097,408 │ flatten[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dropout (Dropout) │ (None, 256) │ 0 │ dense[0][0] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_1 (Dense) │ (None, 2) │ 514 │ dropout[0][0] │ └─────────────────────┴───────────────────┴────────────┴───────────────────┘
Total params: 76,950,664 (293.54 MB)
Trainable params: 25,632,514 (97.78 MB)
Non-trainable params: 53,120 (207.50 KB)
Optimizer params: 51,265,030 (195.56 MB)
#HIDE
# 2) Use it for inference
# Toggle for re-running the (slow, ~minutes) batch prediction over the test set.
make_predictions = True
if make_predictions:
    # `model` and `test_generator` are defined in earlier notebook cells.
    predictions = model.predict(test_generator)
/usr/local/lib/python3.12/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:121: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.
37/37 ━━━━━━━━━━━━━━━━━━━━ 213s 5s/step
#HIDE
y_pred = []
for i in predictions:
y_pred.append(str(np.argmax(i)))
y_pred = np.array(y_pred)
y_true = np.asarray(test['mask'])[:len(predictions)]
print('The model accuracy is {:.2f}'.format(accuracy_score(y_true, y_pred)))
y_pred
predictions
array([[7.5552982e-01, 2.4447022e-01],
[1.6777843e-04, 9.9983215e-01],
[2.1096836e-03, 9.9789029e-01],
...,
[9.9252218e-01, 7.4777524e-03],
[2.5677751e-03, 9.9743217e-01],
[1.0000000e+00, 4.2331166e-10]], dtype=float32)
#HIDE
#plot the confusion matrix
%cd '/content/drive/MyDrive/Colab Notebooks/Explainable-AI'
# NOTE(review): y_pred holds string labels while y_true comes from
# test['mask'] — confusion_matrix needs both to use the same label type;
# presumably they match since this cell ran, but verify upstream.
cmat = confusion_matrix(y_true, y_pred)
fig, ax = plt.subplots(figsize=(4, 3.5))
fig.patch.set_alpha(0)  # transparent figure background
ax.set_facecolor("none")  # transparent axes background
sns.heatmap(cmat, annot=True, fmt="g", ax=ax, cbar=True)
# Light grey text so the figure reads well on the dark site theme.
ax.set_title("Confusion Matrix", color="#e9ecef")
ax.set_xlabel("Predicted", color="#e9ecef")
ax.set_ylabel("Actual", color="#e9ecef")
ax.tick_params(colors="#e9ecef")
plt.tight_layout()
plt.show()
#HIDE
import numpy as np
import tensorflow as tf
import cv2
import matplotlib.pyplot as plt
from matplotlib import cm
from pathlib import Path
def _load_img_for_model(image_path, target_size=(256, 256)):
    """Read an image from disk and prepare it for the classifier.

    Returns a tuple: the resized RGB uint8 image, and a (1, H, W, 3)
    float32 batch scaled to [0, 1].
    """
    raw = cv2.imread(str(image_path))
    if raw is None:
        raise FileNotFoundError(f"Could not read image at: {image_path}")
    # OpenCV loads BGR; the model was trained on RGB input.
    rgb = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, target_size, interpolation=cv2.INTER_AREA)
    batch = np.expand_dims(resized.astype(np.float32) / 255.0, axis=0)
    return resized, batch
def _find_last_conv_layer_name(model):
    """Return the name of the last Conv2D-like layer (works for most CNNs).

    First pass: the last actual Conv2D instance. Fallback: the last layer
    with a 4D (batch, h, w, channels) output.
    """
    for layer in reversed(model.layers):
        if isinstance(layer, keras.layers.Conv2D):
            return layer.name
    # fallback: any layer with 4D output (batch,h,w,channels)
    for layer in reversed(model.layers):
        try:
            # Keras 3 removed `output_shape` on some layers; previously the
            # AttributeError was swallowed and those layers were skipped,
            # which could make this raise even when a suitable layer exists.
            shape = getattr(layer, "output_shape", None)
            if shape is None:
                shape = tuple(layer.output.shape)
            if isinstance(shape, tuple) and len(shape) == 4:
                return layer.name
        except Exception:
            continue
    raise ValueError("Could not automatically find a suitable last conv layer.")
def make_gradcam_attention_map(img_array, model, last_conv_layer_name=None, pred_index=None):
    """
    Standard Grad-CAM (ReLU variant).

    Returns (attention_map_0_1, predicted_class_index, predicted_vector).
    attention_map_0_1 is a 2D float array normalized to [0,1].

    img_array: model-ready batch — presumably (1, H, W, 3) float32 in [0, 1]
    as produced by _load_img_for_model; confirm for other callers.
    """
    if last_conv_layer_name is None:
        last_conv_layer_name = _find_last_conv_layer_name(model)
    # Build a model that maps input -> (last conv activations, model output)
    last_conv_layer = model.get_layer(last_conv_layer_name)
    grad_model = keras.Model([model.inputs], [last_conv_layer.output, model.output])
    with tf.GradientTape() as tape:
        conv_outputs, preds = grad_model(img_array, training=False)
        # Decide which class to explain:
        # - sigmoid/binary: preds shape (1,1) -> use that scalar
        # - softmax: preds shape (1,C) -> argmax if pred_index not given
        preds_np = preds.numpy()
        if pred_index is None:
            if preds_np.ndim == 2 and preds_np.shape[1] > 1:
                pred_index = int(np.argmax(preds_np[0]))
            else:
                pred_index = 0  # single logit/sigmoid
        # The class-channel selection stays inside the tape so the gradient
        # of this scalar w.r.t. conv_outputs is recorded.
        if preds_np.ndim == 2 and preds_np.shape[1] > 1:
            class_channel = preds[:, pred_index]
        else:
            class_channel = preds[:, 0]
    # Gradient of the class score w.r.t. conv feature map
    grads = tape.gradient(class_channel, conv_outputs)
    # Global average pooling over spatial dimensions => importance weights
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    # Weight the channels by corresponding pooled gradients
    conv_outputs = conv_outputs[0]
    attention = tf.reduce_sum(conv_outputs * pooled_grads, axis=-1)
    # ReLU and normalize to [0,1]
    attention = tf.maximum(attention, 0)
    # Epsilon guards against division by zero when the map is all zeros.
    max_val = tf.reduce_max(attention) + tf.keras.backend.epsilon()
    attention_map = attention / max_val
    return attention_map.numpy(), pred_index, preds_np
def show_gradcam_grid(image_path, model, last_conv_layer_name=None, alpha=0.45, out_path=None):
    """
    Displays 3 panels:
      - original
      - attention map (grayscale)
      - heatmap overlay
    Optionally saves the figure to out_path (recommended for nbconvert/GitHub Pages).

    Returns out_path when saving, else None (figure is shown instead).
    alpha controls the heatmap opacity in the overlay.
    """
    # Load image
    img_rgb, img_array = _load_img_for_model(image_path, target_size=(256, 256))
    # Grad-CAM attention map
    attn, pred_index, preds_np = make_gradcam_attention_map(
        img_array, model, last_conv_layer_name=last_conv_layer_name
    )
    # Resize attention map to image size (cv2.resize takes (width, height))
    attn_resized = cv2.resize(attn, (img_rgb.shape[1], img_rgb.shape[0]), interpolation=cv2.INTER_CUBIC)
    # Build a color heatmap (drop the alpha channel from the jet colormap)
    heatmap = (cm.jet(attn_resized)[:, :, :3] * 255).astype(np.uint8)
    # Overlay heatmap onto original
    overlay = cv2.addWeighted(img_rgb, 1.0, heatmap, alpha, 0)
    # Figure — transparent backgrounds so it blends into the site theme
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    fig.patch.set_alpha(0)
    for ax in axs:
        ax.set_facecolor("none")
        ax.axis("off")
    # Titles (use your site-grey)
    grey = "#e9ecef"
    # prediction text: full probability vector for softmax, scalar for sigmoid
    pred_text = ""
    if preds_np.ndim == 2 and preds_np.shape[1] > 1:
        pred_text = f"pred class = {pred_index} | probs = {np.round(preds_np[0], 4)}"
    else:
        pred_text = f"pred score = {float(preds_np[0][0]):.4f}"
    axs[0].set_title("Original MRI", color=grey)
    axs[0].imshow(img_rgb)
    axs[1].set_title("Attention map (Grad-CAM)", color=grey)
    axs[1].imshow(attn_resized, cmap="gray")
    axs[2].set_title("Heatmap overlay", color=grey)
    axs[2].imshow(overlay)
    fig.suptitle(pred_text, color=grey, y=1.02, fontsize=10)
    fig.tight_layout()
    if out_path is not None:
        Path(out_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=200, bbox_inches="tight", transparent=True)
        plt.close(fig)
        return out_path
    else:
        plt.show()
        return None
def _conv_layer_names(model):
    """Return the names of every Conv2D layer in *model*, in model order."""
    names = []
    for layer in model.layers:
        if isinstance(layer, keras.layers.Conv2D):
            names.append(layer.name)
    return names
def make_gradcam_attention_map_abs(img_array, model, last_conv_layer_name=None, pred_index=None):
    """
    Grad-CAM variant that applies |.| instead of ReLU to the weighted map,
    so negatively-contributing regions still show structure.

    Returns (attention_map_0_1, predicted_class_index, predicted_vector),
    same contract as make_gradcam_attention_map.
    """
    if last_conv_layer_name is None:
        last_conv_layer_name = _find_last_conv_layer_name(model)
    last_conv_layer = model.get_layer(last_conv_layer_name)
    grad_model = keras.Model([model.inputs], [last_conv_layer.output, model.output])
    with tf.GradientTape() as tape:
        conv_outputs, preds = grad_model(img_array, training=False)
        preds_np = preds.numpy()
        # Same class-selection rule as the standard variant:
        # softmax (1,C) -> argmax; single logit/sigmoid -> index 0.
        if pred_index is None:
            if preds_np.ndim == 2 and preds_np.shape[1] > 1:
                pred_index = int(np.argmax(preds_np[0]))
            else:
                pred_index = 0
        class_channel = preds[:, pred_index] if (preds_np.ndim == 2 and preds_np.shape[1] > 1) else preds[:, 0]
    grads = tape.gradient(class_channel, conv_outputs)
    # Channel-importance weights: gradients averaged over batch + spatial dims
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    conv_outputs = conv_outputs[0]
    attention = tf.reduce_sum(conv_outputs * pooled_grads, axis=-1)
    # ✅ abs instead of ReLU so you still get structure
    attention = tf.abs(attention)
    max_val = tf.reduce_max(attention) + tf.keras.backend.epsilon()
    attention_map = attention / max_val
    return attention_map.numpy(), pred_index, preds_np
def make_gradcam_attention_map_robust(
    img_array, model, last_conv_layer_name=None, pred_index=None,
    min_std=1e-6, min_max=1e-6
):
    """
    Grad-CAM with fallbacks for degenerate (blank / near-constant) maps.

    Tries standard ReLU Grad-CAM on a sequence of conv layers and keeps the
    first map with enough signal (max > min_max and std > min_std); otherwise
    falls back to the abs variant on the best candidate.

    Returns a 5-tuple: (attention_map, pred_index, preds_np,
    layer_name_used, used_abs_fallback).
    """
    # candidate conv layers: requested one first, then from last->first
    convs = _conv_layer_names(model)
    candidates = []
    if last_conv_layer_name is not None:
        candidates.append(last_conv_layer_name)
    for name in reversed(convs):
        if name not in candidates:
            candidates.append(name)
    # 1) try standard Grad-CAM (ReLU)
    for name in candidates:
        attn, pi, preds_np = make_gradcam_attention_map(
            img_array, model, last_conv_layer_name=name, pred_index=pred_index
        )
        # Accept the first map that is neither blank nor flat.
        if np.nanmax(attn) > min_max and np.nanstd(attn) > min_std:
            return attn, pi, preds_np, name, False
    # 2) fallback: abs Grad-CAM on best candidate
    best = candidates[0] if candidates else _find_last_conv_layer_name(model)
    attn, pi, preds_np = make_gradcam_attention_map_abs(
        img_array, model, last_conv_layer_name=best, pred_index=pred_index
    )
    return attn, pi, preds_np, best, True
def enhance_cam_for_display(cam, clip_percentiles=(2, 98), gamma=0.7, gain=2.0, eps=1e-8):
    """
    Display-only enhancement:
      1) percentile clip to remove outliers
      2) min-max normalize to [0,1]
      3) gamma to boost midtones
      4) gain to make faint maps visible
    """
    # Sanitize non-finite values before any statistics are taken.
    safe = np.nan_to_num(cam, nan=0.0, posinf=0.0, neginf=0.0)
    lo, hi = np.percentile(safe, clip_percentiles)
    clipped = np.clip(safe, lo, hi)
    # eps keeps the division finite for near-constant maps.
    scaled = (clipped - lo) / (hi - lo + eps)
    boosted = scaled ** gamma
    return np.clip(boosted * gain, 0, 1)
Below is an example from the grad-CAM visualisation applied to the brain tumor image classifier model.
#HIDE
# Example: pick first image that your pipeline marked as having a tumor
# Find all indices where y_pred is '1' (predicted as having a tumor)
tumor_predicted_indices = np.where(y_pred == '1')[0]
if len(tumor_predicted_indices) > 0:
    # Get the first index from the list of positive predictions
    row_for_display = tumor_predicted_indices[0]
    # Use this index to get the corresponding image path from the 'test' DataFrame
    image_path = test.iloc[row_for_display]['image_path']
else:
    # Fallback if no tumors are predicted (unlikely given the confusion matrix, but good practice)
    print("Warning: No images were predicted as having a tumor. Displaying Grad-CAM for the first image in the test set.")
    row_for_display = 0
    image_path = test.iloc[row_for_display]['image_path']
# Save the triptych to the docs folder so it can be embedded in the site.
out_path = "docs/pics/gradcam_triptych.png"
saved = show_gradcam_grid(image_path, model, out_path=out_path)
from IPython.display import Image, display
display(Image(filename=saved, width=900))
In the grid below we see that the attention map corresponds well with the clinician's mask — even though the classifier has never seen, or been trained on, the mask.
This block loops over a set of test images (here: cases predicted as containing tumor), loading each MRI slice and its ground-truth mask (if available), and then computing a Grad-CAM attention map for the classifier’s decision. The Grad-CAM attention map is contrast-enhanced for visibility. I chose to look further “in” than the very last convolutional layer (via the robust layer selection) because the final conv features can become too coarse or gradient-saturated, producing weak or even blank maps; earlier convolutional layers often preserve more spatial detail and yield attention maps that align better with the actual tumor region, making the explanation more informative and easier to interpret.
#HIDE
# @title Helper function for plotting the architecture
!apt-get -qq install -y graphviz
!pip -q install pydot
import re
import subprocess
from pathlib import Path
from typing import Dict, Optional, Tuple, List
from tensorflow.keras.utils import model_to_dot
from IPython.display import SVG, display
def export_colored_model_graph(
    model,
    image_base,
    stem: str = "model_arch",
    palette_hex: Optional[List[str]] = None,
    type_color_map: Optional[Dict[str, str]] = None,
    show_shapes: bool = True,
    show_layer_names: bool = True,
    rankdir: str = "TB",
    size: str = "30,18",
    nodesep: str = "0.15",
    ranksep: str = "0.25",
    splines: str = "polyline",
    edge_color: str = "#000000",
    node_border_color: str = "#202020",
    bg_color: str = "white",
    border_width: int = 1,
    display_svg: bool = True,
    return_mapping: bool = True,
) -> Tuple[Path, Path, Path, Optional[Dict[str, str]]]:
    """
    Exports a Keras model graph with colored nodes (stable by layer TYPE in parentheses).
    - Stable mapping: alphabetical layer types -> palette cycle (or use type_color_map to pin types).
    - Fixes Keras HTML labels by rewriting bgcolor/font only where needed.
    - Forces consistent border color via <table color="..."> (does NOT override font colors).
    - Forces all graph edges to edge_color.

    Writes <stem>.dot, <stem>.svg and <stem>.png into image_base and returns
    their paths plus the type->color mapping (or None if return_mapping=False).
    Requires the `dot` binary (graphviz) on PATH.
    """
    if palette_hex is None:
        palette_hex = ["#8ecae6","#219ebc","#126782","#023047","#ffb703","#fd9e02","#fb8500"]
    out_dir = Path(image_base).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    dot_path = out_dir / f"{stem}.dot"
    svg_path = out_dir / f"{stem}.svg"
    png_path = out_dir / f"{stem}.png"
    dot = model_to_dot(
        model,
        show_shapes=show_shapes,
        show_layer_names=show_layer_names,
        rankdir=rankdir,
        expand_nested=False,
    )
    dot.set_graph_defaults(
        splines=splines,
        outputorder="edgesfirst",
        concentrate="false",
        ranksep=ranksep,
        nodesep=nodesep,
        pad="0.2",
        margin="0.2",
        ratio="compress",
        size=size,
    )
    # Node border set here is mostly irrelevant (HTML label draws the box),
    # but doesn't hurt:
    dot.set_node_defaults(fontsize="8", margin="0.02", penwidth=str(max(1, border_width)))
    # Make ALL edges black (and enforce again below on each edge)
    dot.set_edge_defaults(arrowsize="0.4", penwidth="1.1", color=edge_color)
    # ---- helpers ----
    def hex_luma(h):
        # Relative luminance of a "#rrggbb" color (Rec. 709 weights).
        h = h.lstrip("#")
        r, g, b = int(h[0:2], 16)/255, int(h[2:4], 16)/255, int(h[4:6], 16)/255
        return 0.2126*r + 0.7152*g + 0.0722*b
    def best_text_color(fill_hex):
        # White text on dark fills, black on light fills.
        return "white" if hex_luma(fill_hex) < 0.55 else "black"
    # Keras label contains: "<b>layer_name</b> (LayerType)"
    rx = re.compile(r"<b>(?P<lname>[^<]+)</b>\s*\((?P<ltype>[^)]+)\)")
    # Collect all types present
    types_found = []
    for node in dot.get_nodes():
        lab = node.get("label")
        if not lab:
            continue
        m = rx.search(lab)
        if m:
            types_found.append(m.group("ltype"))
    types_found = sorted(set(types_found))
    # Stable mapping type->color
    if type_color_map is not None:
        # NOTE(review): dict(type_color_map) assumes a mapping (or iterable of
        # pairs); a flat list of hex strings would raise here — see call site.
        type_to_color = dict(type_color_map)
        remaining = [t for t in types_found if t not in type_to_color]
        for i, t in enumerate(remaining):
            type_to_color[t] = palette_hex[i % len(palette_hex)]
    else:
        type_to_color = {t: palette_hex[i % len(palette_hex)] for i, t in enumerate(types_found)}
    def _fix_table_tag(table_tag: str) -> str:
        """Ensure the <table ...> has consistent border/grid color and border widths."""
        tag = table_tag
        # Force border + cellborder to be present and consistent
        if re.search(r'\bborder="', tag):
            tag = re.sub(r'\bborder="[^"]*"', f'border="{border_width}"', tag)
        else:
            tag = tag[:-1] + f' border="{border_width}">'
        if re.search(r'\bcellborder="', tag):
            tag = re.sub(r'\bcellborder="[^"]*"', f'cellborder="{border_width}"', tag)
        else:
            tag = tag[:-1] + f' cellborder="{border_width}">'
        # Force table border/grid color (THIS affects table lines, not text fonts)
        if re.search(r'\bcolor="', tag):
            tag = re.sub(r'\bcolor="[^"]*"', f'color="{node_border_color}"', tag)
        else:
            tag = tag[:-1] + f' color="{node_border_color}">'
        return tag
    def recolor_html_label(label: str, header_fill: str) -> str:
        """
        Rewrite Keras HTML label:
        - Replace ALL bgcolor="black" -> bgcolor=bg_color (so no leftover black fills)
        - Recolor ONLY the header cell background
        - Set ONLY the header font color for contrast
        - Force ONLY the table border/grid color (via <table color="...">)
        """
        # 1) Remove all black cell fills (Keras uses black by default)
        lab2 = label.replace('bgcolor="black"', f'bgcolor="{bg_color}"')
        # 2) Recolor header cell background (the td with colspan="2")
        lab2 = re.sub(
            r'(<td[^>]*colspan="2"[^>]*bgcolor=")([^"]*)(")',
            rf'\1{header_fill}\3',
            lab2,
            count=1
        )
        # 3) Set ONLY the header font color (first <font ... color="..."> is usually header)
        header_text = best_text_color(header_fill)
        lab2 = re.sub(
            r'(<font[^>]*color=")([^"]+)(")',
            rf'\1{header_text}\3',
            lab2,
            count=1
        )
        # 4) Force consistent table border/grid styling by editing only the <table ...> tag
        lab2 = re.sub(
            r"<table[^>]*>",
            lambda m: _fix_table_tag(m.group(0)),
            lab2,
            count=1
        )
        return lab2
    # Apply node label rewrites
    for node in dot.get_nodes():
        lab = node.get("label")
        if not lab:
            continue
        m = rx.search(lab)
        if not m:
            continue
        ltype = m.group("ltype")
        fill = type_to_color.get(ltype, palette_hex[0])
        node.set_style("filled,rounded")
        node.set_shape("box")
        node.set_color(node_border_color)  # outer node outline (if used)
        node.set_fontcolor(node_border_color)
        node.set("label", recolor_html_label(lab, fill))
    # Enforce all existing edges black as well
    for e in dot.get_edges():
        e.set_color(edge_color)
    # Write DOT
    dot_path.write_text(dot.to_string())
    # Render with cairo to avoid cropped SVG
    tmp_svg = out_dir / f"{stem}__tmp.svg"
    tmp_png = out_dir / f"{stem}__tmp.png"
    if tmp_svg.exists(): tmp_svg.unlink()
    if tmp_png.exists(): tmp_png.unlink()
    res = subprocess.run(["dot", "-Tsvg:cairo", str(dot_path), "-o", str(tmp_svg)],
                         capture_output=True, text=True)
    if res.returncode != 0:
        raise RuntimeError(f"Graphviz SVG render failed:\n{res.stderr}")
    subprocess.check_call(["dot", "-Tpng", str(dot_path), "-o", str(tmp_png)])
    tmp_svg.replace(svg_path)
    tmp_png.replace(png_path)
    if display_svg:
        svg_text = svg_path.read_text(encoding="utf-8")
        # Wrap the SVG in a horizontally-scrollable container so wide graphs
        # don't get shrunk to fit the page.
        display(HTML(f"""
        <style>
        .svg-scroll {{
          max-width: 100%;
          overflow-x: auto;
          overflow-y: hidden;
          -webkit-overflow-scrolling: touch;
          touch-action: pan-x;
        }}
        .svg-scroll .inner {{
          width: max-content;   /* makes content keep its intrinsic width */
        }}
        .svg-scroll svg {{
          max-width: none !important;   /* prevents “shrink-to-fit” */
          height: auto !important;
          display: block;
        }}
        </style>
        <div class="svg-scroll">
          <div class="inner">
            {svg_text}
          </div>
        </div>
        """))
    return dot_path, svg_path, png_path, (type_to_color if return_mapping else None)
dot_path, svg_path, png_path, type_to_color = export_colored_model_graph(
    model=model,
    image_base=image_base,
    stem="model_arch",
    # NOTE(review): type_color_map is typed as Dict[str, str] (layer type ->
    # hex color) but a `palette_hex` value is passed here; if the global
    # `palette_hex` is a plain list of colors, dict() inside the function
    # will fail — confirm it is a mapping, or pass it as palette_hex= instead.
    type_color_map= palette_hex,
    show_shapes=True,
    rankdir="TB",
    edge_color="#8ECAE6",
    bg_color="white",
    border_width=1,
    display_svg=True,
)
#HIDE
# @title Functions for Grad-CAM
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from matplotlib import cm
from IPython.display import Image, display
def show_gradcam_grid_with_masks(
    test,
    model,
    y_pred,
    n_cases=8,
    n_cols=2,
    grad_cam_setup = 'robust',
    enhance=True,
    out_path=image_base/"gradcam_grid_with_masks.png",
    last_conv_layer_name=None,
    alpha=0.45,
):
    """
    Render a grid of predicted-tumor test cases, four panels per case:
    original MRI, ground-truth mask (if available), Grad-CAM attention map,
    and heatmap overlay. Saves the figure to out_path and returns the path.

    Parameters
    ----------
    test : DataFrame with an 'image_path' column; 'mask_path' and one of
        'mask'/'mask_x'/'mask_y'/'has_mask' are used when present.
    model : trained Keras classifier.
    y_pred : per-row predicted labels; '1' marks a predicted tumor.
    n_cases, n_cols : grid geometry (n_cols cases per row, 4 panels each).
    grad_cam_setup : 'robust', 'abs' or 'regular' Grad-CAM variant.
    enhance : apply display-only contrast enhancement to the attention panel.
    last_conv_layer_name : conv layer to explain (None = auto-detect).
    alpha : heatmap opacity in the overlay.

    Raises
    ------
    ValueError for an unknown grad_cam_setup (previously this fell through
    the if-chain and crashed later with a NameError on `attn`).

    NOTE(review): the out_path default captures the global `image_base` at
    definition time — confirm image_base is defined before this cell runs.
    """
    # pick a ground-truth column for mask presence (0/1), in priority order
    true_col = None
    for cand in ("mask", "mask_x", "mask_y", "has_mask"):
        if cand in test.columns:
            true_col = cand
            break
    # indices (keep original behavior: select predicted tumors)
    tumor_predicted_indices = np.where(np.array(y_pred).astype(str) == "1")[0]
    if len(tumor_predicted_indices) == 0:
        print("Warning: No images predicted as tumor. Using first n_cases images.")
        tumor_predicted_indices = np.arange(min(n_cases, len(test)))
    chosen = tumor_predicted_indices[:n_cases]
    n_cases = len(chosen)
    cases_per_row = n_cols
    n_rows = int(np.ceil(n_cases / cases_per_row))
    fig, axs = plt.subplots(
        n_rows, cases_per_row * 4,
        figsize=(cases_per_row * 4.2 * 4, n_rows * 4.2),
    )
    fig.patch.set_alpha(0)
    # Guarantee 2D indexing even when there is a single row of axes.
    axs = np.atleast_2d(axs)
    grey = "#e9ecef"
    for idx_in_grid, row_for_display in enumerate(chosen):
        r = idx_in_grid // cases_per_row
        c0 = (idx_in_grid % cases_per_row) * 4
        image_path = test.iloc[row_for_display]["image_path"]
        # ---- predicted / true info ----
        pred_val = 1 if str(y_pred[row_for_display]) == "1" else 0
        pred_text = "Tumor" if pred_val == 1 else "No tumor"
        if true_col is not None:
            true_val = int(test.iloc[row_for_display][true_col])
        else:
            true_val = -1  # unknown
        # --- load original image (RGB) ---
        img_bgr = cv2.imread(str(image_path))
        if img_bgr is None:
            for k in range(4):
                axs[r, c0 + k].axis("off")
                axs[r, c0 + k].set_title("Missing image", color=grey)
            continue
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        img_rgb_256 = cv2.resize(img_rgb, (256, 256), interpolation=cv2.INTER_AREA)
        # --- load original mask image if available (grayscale) ---
        mask_original_256 = None
        if "mask_path" in test.columns and pd.notna(test.iloc[row_for_display].get("mask_path", None)):
            mask_path = test.iloc[row_for_display]["mask_path"]
            mask_original = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
            if mask_original is not None:
                # Nearest-neighbour keeps the mask binary after resizing.
                mask_original_256 = cv2.resize(mask_original, (256, 256), interpolation=cv2.INTER_NEAREST)
        # --- grad-cam attention ---
        _, img_array = _load_img_for_model(image_path, target_size=(256, 256))
        if grad_cam_setup == 'robust':
            attn, pred_index, preds_np, used_layer, used_abs = make_gradcam_attention_map_robust(
                img_array, model, last_conv_layer_name=last_conv_layer_name
            )
        elif grad_cam_setup == 'abs':
            attn, pred_index, preds_np = make_gradcam_attention_map_abs(
                img_array, model, last_conv_layer_name=last_conv_layer_name
            )
            used_layer = last_conv_layer_name
            used_abs = True
        elif grad_cam_setup == 'regular':
            attn, pred_index, preds_np = make_gradcam_attention_map(
                img_array, model, last_conv_layer_name=last_conv_layer_name
            )
            used_layer = last_conv_layer_name
            used_abs = False
        else:
            # Fail fast instead of hitting an undefined `attn` below.
            raise ValueError(f"Unknown grad_cam_setup: {grad_cam_setup!r}")
        attn_resized = cv2.resize(attn, (256, 256), interpolation=cv2.INTER_CUBIC)
        # --- display-only normalization so faint maps aren't black ---
        if enhance:
            attn_disp = enhance_cam_for_display(attn_resized, clip_percentiles=(2, 98), gamma=0.7, gain=2.5)
        else:
            attn_disp = attn_resized
        # --- build a color heatmap ---
        # (heatmap/overlay deliberately use the raw, un-enhanced CAM)
        heatmap = (cm.jet(attn_resized)[:, :, :3] * 255).astype(np.uint8)
        overlay = cv2.addWeighted(img_rgb_256, 1.0, heatmap, alpha, 0)
        # --- plot panels ---
        # 1) MRI + predicted/true label info
        ax = axs[r, c0 + 0]
        ax.set_facecolor("none")
        ax.imshow(img_rgb_256)
        ax.set_title(
            f"Case {idx_in_grid+1}\nPred: {pred_text} ({pred_val}) | True mask: {true_val}",
            color=grey, fontsize=10
        )
        ax.axis("off")
        # 2) original mask image (if exists)
        ax = axs[r, c0 + 1]
        ax.set_facecolor("none")
        if mask_original_256 is None:
            ax.text(0.5, 0.5, "No mask image\navailable", ha="center", va="center", color=grey, fontsize=10)
        else:
            ax.imshow(mask_original_256, cmap="gray")
        ax.set_title("Original Mask Image", color=grey, fontsize=10)
        ax.axis("off")
        # 3) attention map
        ax = axs[r, c0 + 2]
        ax.set_facecolor("none")
        ax.imshow(attn_disp, cmap="gray", vmin=0, vmax=1)
        note = "abs" if used_abs else "relu"
        ax.set_title(f"Grad-CAM\n({note}, {used_layer})", color=grey, fontsize=10)
        ax.axis("off")
        # 4) overlay
        ax = axs[r, c0 + 3]
        ax.set_facecolor("none")
        ax.imshow(overlay)
        if preds_np.ndim == 2 and preds_np.shape[1] > 1:
            pred_info = f"class={pred_index}"
        else:
            pred_info = f"score={float(preds_np[0][0]):.3f}"
        ax.set_title(f"Overlay\n({pred_info})", color=grey, fontsize=10)
        ax.axis("off")
    # hide any unused axes
    total_axes = n_rows * cases_per_row * 4
    used_axes = n_cases * 4
    if used_axes < total_axes:
        flat = axs.ravel()
        for j in range(used_axes, total_axes):
            flat[j].axis("off")
            flat[j].set_facecolor("none")
    fig.tight_layout()
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=200, bbox_inches="tight", transparent=True)
    plt.close(fig)
    return out_path
# Visualize Grad-CAM against the 3x3 conv inside the last residual block.
conv_l = "conv5_block3_2_conv"
# call
out_path = show_gradcam_grid_with_masks(
    test=test,
    model=model,
    y_pred=y_pred,
    n_cases=8,
    n_cols=1,
    grad_cam_setup='regular',
    enhance=False,
    last_conv_layer_name = conv_l,
    out_path=image_base/"gradcam_grid_with_masks.png",
)
display(Image(filename=out_path, width=1100))
# Same grid, but explained at the final activation of the last residual block.
conv_l = "conv5_block3_out"
# call
out_path = show_gradcam_grid_with_masks(
    test=test,
    model=model,
    y_pred=y_pred,
    n_cases=8,
    n_cols=1,
    grad_cam_setup='regular',
    enhance=False,
    last_conv_layer_name = conv_l,
    out_path=image_base/"gradcam_grid_with_masks.png",
)
display(Image(filename=out_path, width=1100))
Output hidden; open in https://colab.research.google.com to view.
# One stage earlier (16x16 features): finer spatial detail, no enhancement.
conv_l = "conv4_block6_out"
# call
out_path = show_gradcam_grid_with_masks(
    test=test,
    model=model,
    y_pred=y_pred,
    n_cases=8,
    n_cols=1,
    grad_cam_setup='regular',
    enhance=False,
    last_conv_layer_name = conv_l,
    out_path=image_base/"gradcam_grid_with_masks.png",
)
display(Image(filename=out_path, width=1100))
Output hidden; open in https://colab.research.google.com to view.
# Same layer as above, but with display-only contrast enhancement turned on.
conv_l = "conv4_block6_out"
# call
out_path = show_gradcam_grid_with_masks(
    test=test,
    model=model,
    y_pred=y_pred,
    n_cases=8,
    n_cols=1,
    grad_cam_setup='regular',
    enhance=True,
    last_conv_layer_name = conv_l,
    out_path=image_base/"gradcam_grid_with_masks.png",
)
display(Image(filename=out_path, width=1100))
Output hidden; open in https://colab.research.google.com to view.
# call
# Explain at the last 1x1 projection conv of the final residual block.
conv_l = "conv5_block3_3_conv"
out_path = show_gradcam_grid_with_masks(
    test=test,
    model=model,
    y_pred=y_pred,
    n_cases=8,
    n_cols=1,
    grad_cam_setup = 'regular',
    enhance=False,
    last_conv_layer_name = conv_l,
    out_path=image_base/"gradcam_grid_with_masks.png",
)
display(Image(filename=out_path, width=1100))
Output hidden; open in https://colab.research.google.com to view.
# Explain at the first 1x1 conv of the final residual block.
conv_l = "conv5_block3_1_conv"
out_path = show_gradcam_grid_with_masks(
    test=test,
    model=model,
    y_pred=y_pred,
    n_cases=8,
    n_cols=1,
    grad_cam_setup = 'regular',
    enhance=False,
    last_conv_layer_name = conv_l,
    out_path=image_base/"gradcam_grid_with_masks.png",
)
display(Image(filename=out_path, width=1100))
Output hidden; open in https://colab.research.google.com to view.
# Per-class precision/recall/F1 on the test predictions; labels=[0, 1] pins
# the row order. (Re-import is redundant — already imported in the setup
# cell — but harmless in a notebook.)
from sklearn.metrics import classification_report
print(classification_report(y_true, y_pred, labels = [0,1]))
precision recall f1-score support
0 0.98 0.98 0.98 383
1 0.97 0.97 0.97 207
micro avg 0.98 0.98 0.98 590
macro avg 0.98 0.98 0.98 590
weighted avg 0.98 0.98 0.98 590
#HIDE
# Master switch for the git cells below; set False to skip staging/commit/push.
push_git = True
# @title GIT commands
# Add changes
# Stage the entire working tree.
if push_git:
!git add .
#HIDE
# Short-format branch/status overview, then make sure .gitignore is staged.
!git status -sb
!git add .gitignore
#HIDE
%%capture
# Commit changes and remember to change commit message
# (guarded by the push_git flag set in the "GIT commands" cell above)
if push_git:
!git commit -m "changed chart" #Remember to change commit message
#HIDE
%%capture
# Push changes
if push_git:
with open('/content/drive/MyDrive/tokens/token1.txt', 'r') as f:
token = f.read().strip()
!git push https://github.com/KaisuH/Explainable-AI.git
#HIDE
#for removing the last line in gitignore
#!sed -i '$d' .gitignore
#for printing the gitignore
#!cat .gitignore
#HIDE
make_new_branch = False
if make_new_branch:
# 1. Create a brand-new “orphan” branch (no history)
!git checkout --orphan clean-start
# 2. Stage everything in your current working directory
!git add -A
# 3. Commit it as your one “fresh start” commit
!git commit -m "Fresh start: keep only current work"
# 4. Force-push this new branch to overwrite remote main
!git push https://$token@github.com/KaisuH/Emotion-AI.git clean-start:main --force
# 5. (Optional) Switch back to ‘main’ locally and delete the temp branch
!git checkout main
!git branch -D clean-start
#HIDE
# @title Producing the final file
%%capture
%cd '/content/drive/MyDrive/Colab Notebooks/Explainable-AI'
import nbformat
nb_name = "master.ipynb" # original notebook
out_nb_name = "master_tagged.ipynb" # new notebook with tags
nb = nbformat.read(nb_name, as_version=4)
for cell in nb.cells:
if cell.cell_type == 'code':
# Check if the first line of the cell source is #HIDE
first_line = cell.source.strip().split('\n',1)[0]
if first_line.strip().startswith('#HIDE'):
cell.metadata.setdefault('tags', []).append('hide_input')
nbformat.write(nb, out_nb_name)
#HIDE
%%capture
# Write the nbconvert config that (a) removes inputs of cells tagged
# 'hide_input' and (b) hides In[]/Out[] prompts. The body below is raw JSON
# written verbatim to hide_code_config.json — do not add comments inside it.
%%writefile hide_code_config.json
{
"TagRemovePreprocessor": {
"enabled": true,
"remove_input_tags": ["hide_input"]
},
"Exporter": {
"exclude_input_prompt": true,
"exclude_output_prompt": true
}
}
#HIDE
%%capture
# 1) Convert notebook -> docs/index.html
# Uses the hide_code_config.json written above to drop tagged inputs.
!jupyter nbconvert --to html --config hide_code_config.json \
--output "docs/index.html" "master_tagged.ipynb"
import re
from pathlib import Path

# -----------------------------
# Settings
# -----------------------------
TITLE = "Explainable AI"
docs_dir = Path("docs")
html_path = docs_dir / "index.html"
# Assets are NOW in docs/ (same folder as index.html)
bootstrap_file = docs_dir / "bootstrap.min.css"
navbar_file = docs_dir / "navbar.html"
bootstrap_href = "bootstrap.min.css"  # relative to docs/index.html

# -----------------------------
# Sanity checks — fail fast, in order, before touching the HTML.
# -----------------------------
for required_file, missing_msg in (
    (html_path, "docs/index.html not found (nbconvert failed?)"),
    (bootstrap_file, "docs/bootstrap.min.css not found"),
    (navbar_file, "docs/navbar.html not found"),
):
    if not required_file.exists():
        raise FileNotFoundError(missing_msg)

# -----------------------------
# Read generated HTML
# -----------------------------
html = html_path.read_text(encoding="utf-8")
def remove_block(text: str, start_marker: str, end_marker: str) -> str:
    """Delete every start..end marker span (markers included) from *text*.

    Non-greedy and DOTALL, so each span may cross newlines and multiple
    spans are removed independently. Used to make HTML injections idempotent.
    """
    pattern = f"{re.escape(start_marker)}.*?{re.escape(end_marker)}"
    return re.sub(pattern, "", text, flags=re.DOTALL)
# -----------------------------
# Title (robust)
# -----------------------------
# Replace only the first <title>; DOTALL tolerates a title spanning lines.
html = re.sub(r"<title>.*?</title>", f"<title>{TITLE}</title>", html, count=1,
flags=re.IGNORECASE | re.DOTALL)
# -----------------------------
# Bootstrap CSS link (idempotent)
# -----------------------------
# Strip any previously injected link block first, so re-running this cell
# never duplicates the <link> tag.
BOOT_START = "<!-- BOOTSTRAP LINK START -->"
BOOT_END = "<!-- BOOTSTRAP LINK END -->"
html = remove_block(html, BOOT_START, BOOT_END)
bootstrap_link_block = f"""{BOOT_START}
<link rel="stylesheet" href="{bootstrap_href}">
{BOOT_END}
"""
# Insert just before </head> so it loads after nbconvert's own styles.
html = re.sub(r"</head>", bootstrap_link_block + "\n</head>", html, count=1, flags=re.IGNORECASE)
# -----------------------------
# Navbar injection:
# - put ALL <style>...</style> from navbar.html into <head>
# - put ONLY <nav>...</nav> into <body>
# -----------------------------
navbar_raw = navbar_file.read_text(encoding="utf-8")
# Marker comments make both injections idempotent across re-runs.
NAVSTYLE_START = "<!-- NAVBAR STYLES START -->"
NAVSTYLE_END = "<!-- NAVBAR STYLES END -->"
NAV_START = "<!-- NAVBAR START -->"
NAV_END = "<!-- NAVBAR END -->"
html = remove_block(html, NAVSTYLE_START, NAVSTYLE_END)
html = remove_block(html, NAV_START, NAV_END)
# Extract styles
# Collect every <style> block from navbar.html (there may be several).
navbar_styles = "\n".join(
re.findall(r"<style[\s\S]*?</style>", navbar_raw, flags=re.IGNORECASE)
).strip()
if navbar_styles:
html = re.sub(
r"</head>",
f"{NAVSTYLE_START}\n{navbar_styles}\n{NAVSTYLE_END}\n</head>",
html,
count=1,
flags=re.IGNORECASE
)
# Extract nav
# The <nav> element is mandatory — fail loudly if navbar.html lacks one.
nav_match = re.search(r"<nav[\s\S]*?</nav>", navbar_raw, flags=re.IGNORECASE)
if not nav_match:
raise ValueError("Could not find a <nav>...</nav> block inside docs/navbar.html")
navbar_nav = nav_match.group(0).strip()
# Insert navbar right after opening <body ...>
html = re.sub(
r"(<body[^>]*>)",
r"\1\n" + f"{NAV_START}\n{navbar_nav}\n{NAV_END}\n",
html,
count=1,
flags=re.IGNORECASE
)
# -----------------------------
# ✅ CONTENT COLUMN WRAPPER (idempotent)
# This centers ONLY the notebook content AFTER the navbar.
# -----------------------------
COL_START = "<!-- CONTENT COLUMN START -->"
COL_END = "<!-- CONTENT COLUMN END -->"
html = remove_block(html, COL_START, COL_END)
# Start wrapper right after NAV_END
if NAV_END in html:
    html = html.replace(
        NAV_END,
        NAV_END + f"\n{COL_START}\n<div class=\"content-column\">\n",
        1
    )
    # Close wrapper right before </body> (so navbar is NOT inside the column).
    # Bug fix: the closing </div> is emitted only when the opening <div> was
    # actually inserted — previously it was appended unconditionally, so a
    # missing NAV_END marker left the document with a stray, unbalanced </div>.
    html = re.sub(
        r"</body>",
        f"\n</div>\n{COL_END}\n</body>",
        html,
        count=1,
        flags=re.IGNORECASE
    )
# -----------------------------
# Dark background + scroll fix (idempotent)
# -----------------------------
# Remove any previous injection, rebuild the CSS block, and insert it at the
# end of <head> so it wins over nbconvert's defaults. The f-string below uses
# doubled braces ({{ }}) so CSS braces survive formatting.
DARK_START = "<!-- DARK OVERRIDES START -->"
DARK_END = "<!-- DARK OVERRIDES END -->"
html = remove_block(html, DARK_START, DARK_END)
dark_css = f"""{DARK_START}
<style>
html, body {{
background: linear-gradient(360deg, #3a3f44, #272b30, #1b1f24) !important;
background-attachment: fixed !important;
margin: 0 !important;
padding: 0 !important;
color: #aaa !important;
}}
/* ✅ CENTER COLUMN (affects only what we wrapped) */
.content-column {{
max-width: 1140px !important;
margin: 40px auto 0 auto !important;
padding-left: 15px !important;
padding-right: 15px !important;
}}
#notebook-container,
.container, .container-fluid,
.jp-Notebook, .jp-NotebookPanel, .jp-NotebookPanel-notebook,
.jp-Cell, .jp-Cell-inputWrapper, .jp-Cell-outputWrapper,
.jp-OutputArea, .jp-OutputArea-output,
.jp-RenderedHTMLCommon, .jp-RenderedText {{
background: transparent !important;
color: inherit !important;
}}
/*Make *output* code/text green (prints, tracebacks, etc.) */
.jp-OutputArea pre,
.jp-OutputArea-output pre,
.jp-OutputArea-output,
.jp-OutputArea-output code,
.output pre,
.output code {{color: #66ff66 !important;
}}
/*Mauve hyperlinks everywhere EXCEPT inside the navbar */
a, a:link, a:visited {{
color: #993461 !important;
}}
a:hover, a:focus, a:active {{
color: #993461 !important;
text-decoration: underline;
}}
/* Don't override navbar link colors */
nav a, .navbar a {{
color: inherit !important;
text-decoration: none !important;
}}
.jp-RenderedText pre,
.jp-RenderedHTMLCommon pre,
.jp-Cell .jp-Cell-inputWrapper .CodeMirror-lines,
.jp-Cell .jp-Cell-inputWrapper pre,
pre {{
white-space: pre !important;
overflow-x: auto !important;
max-width: 100% !important;
display: block !important;
}}
</style>
{DARK_END}
"""
html = re.sub(r"</head>", dark_css + "\n</head>", html, count=1, flags=re.IGNORECASE)
# -----------------------------
# Bootstrap JS for navbar toggler (idempotent)
# -----------------------------
# jQuery + Popper + Bootstrap 4 bundle from CDNs, injected just before
# </body>; the marker comments make re-runs replace rather than duplicate.
JS_START = "<!-- BOOTSTRAP NAVBAR JS START -->"
JS_END = "<!-- BOOTSTRAP NAVBAR JS END -->"
html = remove_block(html, JS_START, JS_END)
bootstrap_js = f"""{JS_START}
<script src="https://code.jquery.com/jquery-3.5.1.slim.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/popper.js@1.16.1/dist/umd/popper.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@4.5.2/dist/js/bootstrap.min.js"></script>
{JS_END}
"""
html = re.sub(r"</body>", bootstrap_js + "\n</body>", html, count=1, flags=re.IGNORECASE)
# -----------------------------
# Save
# -----------------------------
# Write the fully modified page back over docs/index.html.
html_path.write_text(html, encoding="utf-8")
#HIDE
%%capture
# Run this line to produce the output
!jupyter nbconvert --to html --config hide_code_config.json \
--output "docs/index.html" "master_tagged.ipynb"
### Renaming the document to produce a nicer name in web tab ###
# After nbconvert generates the HTML:
with open("docs/index.html", "r") as f:
html_content = f.read()
# Find and replace the title tag content
new_title = "Brain tumor detector" # Your desired title
html_content = html_content.replace("<title>master_tagged</title>", f"<title>{new_title}</title>")
# Insert CSS just before </head> to ensure it loads after default styles
style_snippet = """
<style>
/* Force horizontal scroll on wide code cells */
.jp-RenderedText pre,
.jp-RenderedHTMLCommon pre,
.jp-Cell .jp-Cell-inputWrapper .CodeMirror-lines,
.jp-Cell .jp-Cell-inputWrapper pre {
white-space: pre !important;
overflow-x: auto !important;
max-width: 100% !important;
display: block !important;
}
</style>
"""
# Insert style snippet at the end of <head>, or if you prefer, after <head>:
html_content = html_content.replace(
"</head>",
f"{style_snippet}\n</head>"
)
# Save the modified HTML
with open("docs/index.html", "w") as f:
f.write(html_content)